{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# (beta) Building a Simple CPU Performance Profiler with FX\n**Author**: [James Reed](https://github.com/jamesr66a)\n\nIn this tutorial, we are going to use FX to do the following:\n\n1) Capture PyTorch Python code in a way that we can inspect and gather\n   statistics about the structure and execution of the code\n2) Build out a small class that will serve as a simple performance \"profiler\",\n   collecting runtime statistics about each part of the model from actual\n   runs.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For this tutorial, we are going to use the torchvision ResNet18 model\nfor demonstration purposes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.fx\nimport torchvision.models as models\n\nrn18 = models.resnet18()\nrn18.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now that we have our model, we want to look more deeply into its\nperformance. That is, for the following invocation, which parts\nof the model take the longest?\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "input = torch.randn(5, 3, 224, 224)\noutput = rn18(input)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A common way of answering that question is to go through the program\nsource, add code that collects timestamps at various points in the\nprogram, and compare the difference between those timestamps to see\nhow long the regions between the timestamps take.\n\nThat technique is certainly applicable to PyTorch code; however, it\nwould be nicer if we didn't have to copy over model code and edit it,\nespecially code we haven't written (like this torchvision model).\nInstead, we are going to use FX to automate this \"instrumentation\"\nprocess without needing to modify any source.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's get some imports out of the way (we will be using all\nof these later in the code).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import statistics, tabulate, time\nfrom typing import Any, Dict, List\nfrom torch.fx import Interpreter"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>``tabulate`` is an external library that is not a dependency of PyTorch.\n    We will be using it to more easily visualize performance data. Please\n    make sure you've installed it from your favorite Python package source.</p></div>\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Capturing the Model with Symbolic Tracing\nNext, we are going to use FX's symbolic tracing mechanism to capture\nthe definition of our model in a data structure we can manipulate\nand examine.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "traced_rn18 = torch.fx.symbolic_trace(rn18)\nprint(traced_rn18.graph)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This gives us a Graph representation of the ResNet18 model. A Graph\nconsists of a series of Nodes connected to each other. Each Node\nrepresents a call-site in the Python code (whether to a function,\na module, or a method) and the edges (represented as ``args`` and ``kwargs``\non each node) represent the values passed between these call-sites. More\ninformation about the Graph representation and the rest of FX's APIs can\nbe found in the FX documentation: https://pytorch.org/docs/master/fx.html.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Creating a Profiling Interpreter\nNext, we are going to create a class that inherits from ``torch.fx.Interpreter``.\nThough the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code\nthat is run when you call a ``GraphModule``, an alternative way to run a\n``GraphModule`` is by executing each ``Node`` in the ``Graph`` one by one. That is\nthe functionality that ``Interpreter`` provides: it interprets the graph\nnode by node.\n\nBy inheriting from ``Interpreter``, we can override various functionality and\ninstall the profiling behavior we want. The goal is to have an object to which\nwe can pass a model, invoke the model one or more times, then get statistics about\nhow long the model and each part of the model took during those runs.\n\nLet's define our ``ProfilingInterpreter`` class:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class ProfilingInterpreter(Interpreter):\n    def __init__(self, mod : torch.nn.Module):\n        # Rather than have the user symbolically trace their model,\n        # we're going to do it in the constructor. As a result, the\n        # user can pass in any ``Module`` without having to worry about\n        # symbolic tracing APIs\n        gm = torch.fx.symbolic_trace(mod)\n        super().__init__(gm)\n\n        # We are going to store away two things here:\n        #\n        # 1. A list of total runtimes for ``mod``. In other words, we are\n        #    storing away the time ``mod(...)`` took each time this\n        #    interpreter is called.\n        self.total_runtime_sec : List[float] = []\n        # 2. A map from ``Node`` to a list of times (in seconds) that\n        #    node took to run. This can be seen as similar to (1) but\n        #    for specific sub-parts of the model.\n        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}\n\n    ######################################################################\n    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``\n    # method is the top-level entrypoint for execution of the model. We will\n    # want to intercept this so that we can record the total runtime of the\n    # model.\n\n    def run(self, *args) -> Any:\n        # Record the time we started running the model\n        t_start = time.time()\n        # Run the model by delegating back into Interpreter.run()\n        return_val = super().run(*args)\n        # Record the time we finished running the model\n        t_end = time.time()\n        # Store the total elapsed time this model execution took in the\n        # ProfilingInterpreter\n        self.total_runtime_sec.append(t_end - t_start)\n        return return_val\n\n    ######################################################################\n    # Now, let's override ``run_node``. 
``Interpreter`` calls ``run_node`` each\n    # time it executes a single node. We will intercept this so that we\n    # can measure and record the time taken for each individual call in\n    # the model.\n\n    def run_node(self, n : torch.fx.Node) -> Any:\n        # Record the time we started running the op\n        t_start = time.time()\n        # Run the op by delegating back into Interpreter.run_node()\n        return_val = super().run_node(n)\n        # Record the time we finished running the op\n        t_end = time.time()\n        # If we don't have an entry for this node in our runtimes_sec\n        # data structure, add one with an empty list value.\n        self.runtimes_sec.setdefault(n, [])\n        # Record the total elapsed time for this single invocation\n        # in the runtimes_sec data structure\n        self.runtimes_sec[n].append(t_end - t_start)\n        return return_val\n\n    ######################################################################\n    # Finally, we are going to define a method (one which doesn't override\n    # any ``Interpreter`` method) that provides us a nice, organized view of\n    # the data we have collected.\n\n    def summary(self, should_sort : bool = False) -> str:\n        # Build up a list of summary information for each node\n        node_summaries : List[List[Any]] = []\n        # Calculate the mean runtime for the whole network. Because the\n        # network may have been called multiple times during profiling,\n        # we need to summarize the runtimes. 
We choose to use the\n        # arithmetic mean for this.\n        mean_total_runtime = statistics.mean(self.total_runtime_sec)\n\n        # For each node, record summary statistics\n        for node, runtimes in self.runtimes_sec.items():\n            # Similarly, compute the mean runtime for ``node``\n            mean_runtime = statistics.mean(runtimes)\n            # For easier understanding, we also compute the percentage\n            # time each node took with respect to the whole network.\n            pct_total = mean_runtime / mean_total_runtime * 100\n            # Record the node's type, name of the node, mean runtime, and\n            # percent runtime\n            node_summaries.append(\n                [node.op, str(node), mean_runtime, pct_total])\n\n        # One of the most important questions to answer when doing performance\n        # profiling is \"Which op(s) took the longest?\". We can make this easy\n        # to see by providing sorting functionality in our summary view\n        if should_sort:\n            node_summaries.sort(key=lambda s: s[2], reverse=True)\n\n        # Use the ``tabulate`` library to create a well-formatted table\n        # presenting our summary information\n        headers : List[str] = [\n            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'\n        ]\n        return tabulate.tabulate(node_summaries, headers=headers)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-info\"><h4>Note</h4><p>We use Python's ``time.time`` function to pull wall clock\n      timestamps and compare them. This is not the most accurate\n      way to measure performance, and will only give us a\n      first-order approximation. We use this simple technique only for the\n      purpose of demonstration in this tutorial.</p></div>\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Investigating the Performance of ResNet18\nWe can now use ``ProfilingInterpreter`` to inspect the performance\ncharacteristics of our ResNet18 model:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "interp = ProfilingInterpreter(rn18)\ninterp.run(input)\nprint(interp.summary(True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "There are two things we should call out here:\n\n* MaxPool2d takes up the most time. This is a known issue:\n  https://github.com/pytorch/pytorch/issues/51393\n* BatchNorm2d also takes up significant time. We can continue this\n  line of thinking and optimize this in the Conv-BN Fusion with FX\n  [tutorial](https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html).\n\n\n## Conclusion\nAs we can see, using FX we can easily capture PyTorch programs (even\nones we don't have the source code for!) in a machine-interpretable\nformat and use that for analysis, such as the performance analysis\nwe've done here. FX opens up an exciting world of possibilities for\nworking with PyTorch programs.\n\nFinally, since FX is still in beta, we would be happy to hear any\nfeedback you have about using it. Please feel free to use the\nPyTorch Forums (https://discuss.pytorch.org/) and the issue tracker\n(https://github.com/pytorch/pytorch/issues) to provide any feedback\nyou might have.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}